# Copyright 2025 The TensorStore Authors
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on the @rules_python//python/private/local_runtime_repo.bzl
#
"""Repository rule for Python local tool configuration.

`local_python_runtime` depends on the following environment variables:

  * `PYTHON_BIN_PATH`: location of python binary.
  * `TENSORSTORE_PYTHON_CONFIG_REPO`: location of remote python configuration repository.

There are two essential changes here compared to @rules_python//python/local_toolchains
and @rules_python//python/private/local_runtime_repo.bzl:

1. The environment variable `PYTHON_BIN_PATH` overrides interpreter_path resolution;
   the environment variable `TENSORSTORE_PYTHON_CONFIG_REPO` overrides local configuration.

2. It happens before @rules_python is loaded, so it cannot reference rules_python.
"""

load(
    "//bazel/repo_rules:py_utils.bzl",
    "py_utils",
    _ENV_VARS = "ENV_VARS",
)
load(
    "//bazel/repo_rules:repo_utils.bzl",
    "repo_utils",
)

# Name of the environment variable that, when set, holds the label prefix of a
# remote Python configuration repository. Its `BUILD` file is then used in
# place of locally detected settings.
_TENSORSTORE_PYTHON_CONFIG_REPO = "TENSORSTORE_PYTHON_CONFIG_REPO"

# Contents of the generated BUILD.bazel file, expanded with `str.format`.
# Placeholders wrapped in quotes receive plain strings; the unquoted ones are
# substituted verbatim, so their values must already be valid Starlark
# expressions (the callers in this file pass `repr(...)` output or literals
# such as "None" and "[]").
_TOOLCHAIN_IMPL_TEMPLATE = """\
# Generated by bazel/repo_rules/local_python_runtime.bzl

package(
    default_visibility = ["//visibility:public"],
)

load(
    "@tensorstore//bazel/repo_rules:local_runtime_repo_setup.bzl",
    "define_local_runtime_toolchain_impl",
)

define_local_runtime_toolchain_impl(
    name = "local_runtime",
    major = "{major}",
    minor = "{minor}",
    micro = "{micro}",
    abi_flags = "{abi_flags}",
    os = "{os}",
    implementation_name = "{implementation_name}",
    interpreter_path = "{interpreter_path}",
    interface_library = {interface_library},
    libraries = {libraries},
    defines = {defines},
    abi3_interface_library = {abi3_interface_library},
    abi3_libraries = {abi3_libraries},
    additional_dlls = {additional_dlls},
)
"""

def _expand_incompatible_template():
    """Renders the toolchain BUILD template for a runtime that cannot be used.

    Returns:
        BUILD file content in which `os` is `@platforms//:incompatible` and
        every version, path and library field holds a placeholder value.
    """
    placeholders = {
        "major": "0",
        "minor": "0",
        "micro": "0",
        "abi_flags": "",
        "os": "@platforms//:incompatible",
        "implementation_name": "incompatible",
        "interpreter_path": "/incompatible",
        # The remaining fields are substituted unquoted, so they are spelled
        # as Starlark source text.
        "interface_library": "None",
        "libraries": "[]",
        "defines": "[]",
        "abi3_interface_library": "None",
        "abi3_libraries": "[]",
        "additional_dlls": "[]",
    }
    return _TOOLCHAIN_IMPL_TEMPLATE.format(**placeholders)

def _symlink_libraries(rctx, logger, libraries, shlib_suffix):
    """Symlinks each existing library into this repository's lib/ directory.

    Files are linked individually rather than linking whole directories
    because shared_lib_dirs contains multiple search paths: the Python
    libraries may be missing from any of them, and any of them may also hold
    non-Python runtime libraries (as would be the case if LIBDIR were, for
    example, /usr/lib).

    Args:
        rctx: A repository_ctx object
        logger: A repo_utils.logger object
        libraries: paths to libraries to attempt to symlink.
        shlib_suffix: Optional. Ensure that the generated symlinks end with this suffix.

    Returns:
        A list with the repository-relative `lib/...` path of every symlink
        that was created. Paths in `libraries` that do not exist are skipped.
    """
    linked = []
    for library in libraries:
        origin = rctx.path(library)

        # The reported names don't always exist; it depends on the particulars
        # of the runtime installation.
        if not origin.exists:
            continue

        basename = origin.basename
        if shlib_suffix and not basename.endswith(shlib_suffix):
            extra_suffix = shlib_suffix
        else:
            extra_suffix = ""
        target = "lib/{}{}".format(basename, extra_suffix)

        logger.debug(lambda: "Symlinking {} to {}".format(origin, target))
        repo_utils.watch(rctx, origin)
        rctx.symlink(origin, target)
        linked.append(target)
    return linked

def _local_python_repo_impl(rctx):
    """Implementation of the `local_python_runtime` repository rule.

    If `TENSORSTORE_PYTHON_CONFIG_REPO` is present in the environment, the
    `BUILD` file of that repository is instantiated as this repository's
    `BUILD` file and nothing else happens. Otherwise the local interpreter is
    located and queried, its headers and libraries are symlinked into this
    repository, and a `BUILD.bazel` calling
    `define_local_runtime_toolchain_impl` is written. When the interpreter
    cannot be found or queried, the failure is reported through the logger
    method selected by the `on_failure` attribute and an incompatible
    placeholder `BUILD.bazel` is written instead.

    Args:
        rctx: A repository_ctx object.
    """
    environ = rctx.os.environ
    if _TENSORSTORE_PYTHON_CONFIG_REPO in environ:
        config_repo = environ[_TENSORSTORE_PYTHON_CONFIG_REPO]
        rctx.template("BUILD", Label(config_repo + ":BUILD"), {})
        return

    # No remote configuration: build the repository from the local Python.
    logger = repo_utils.logger(rctx)
    failure_mode = rctx.attr.on_failure

    def _report(message):
        # Dispatch to the logger method that matches `on_failure`.
        if failure_mode == "fail":
            logger.fail(message)
        elif failure_mode == "warn":
            logger.warn(message)
        else:
            logger.info(message)

    lookup = py_utils.get_python_interpreter(rctx, rctx.attr.interpreter_path)
    if not lookup.resolved_path:
        _report(lambda: "python interpreter not found: {}".format(lookup.describe_failure()))

        # Reaching this point means the failure was not fatal; emit a stub.
        rctx.file("BUILD.bazel", _expand_incompatible_template())
        return

    probe = py_utils.get_python_info(
        rctx,
        interpreter_path = lookup.resolved_path,
        logger = logger,
    )
    if not probe.info:
        _report(probe.describe_failure)

        # Reaching this point means the failure was not fatal; emit a stub.
        rctx.file("BUILD.bazel", _expand_incompatible_template())
        return
    info = probe.info

    # base_executable is preferred over sys.executable because the path must
    # lie inside a real Python installation directory ("PYTHONHOME"):
    # * Inside an activated venv, the venv's `bin/python3` is not an actual
    #   Python installation.
    # * If sys.executable is a wrapper (e.g. pyenv), it may live outside any
    #   installation directory and can interfere with Python recognizing
    #   that it is running within a venv.
    #
    # The value may be a symlink (usually e.g. `python3->python3.12`); it is
    # deliberately not realpath()'d, to respect what Python has decided is
    # the appropriate path.
    base_executable = info["base_executable"]
    logger.info(lambda: "Found external Python interpreter {}".format(base_executable))

    # NOTE: Keep in sync with recursive glob in define_local_runtime_toolchain_impl
    include_dir = rctx.path(info["include"])

    # The reported include path may not exist, and watching a non-existent
    # path is an error. It is silently skipped, since includes are only
    # necessary if C extensions are built.
    if include_dir.exists and include_dir.is_dir:
        repo_utils.watch_tree(rctx, include_dir)

    # cc_library.includes values have to be non-absolute paths, otherwise the
    # toolchain gives an error. Symlinking makes the headers appear as part
    # of this repository.
    logger.debug(lambda: "Symlinking {} to include".format(include_dir))
    rctx.symlink(include_dir, "include")

    rctx.report_progress("Symlinking external Python shared libraries")

    def _link_dynamic_pair(dynamic_key, interface_key):
        # Links the first dynamic library (with the shared library suffix
        # enforced) and the first interface library listed under the given
        # keys. Returns the linked dynamic libraries and the interface
        # library path, or None when no interface library was linked.
        dynamic_links = _symlink_libraries(rctx, logger, info[dynamic_key][:1], info["shlib_suffix"])
        interface_links = _symlink_libraries(rctx, logger, info[interface_key][:1], None)
        return dynamic_links, (interface_links[0] if interface_links else None)

    if info["dynamic_libraries"]:
        libraries, interface_library = _link_dynamic_pair(
            "dynamic_libraries",
            "interface_libraries",
        )
    else:
        # Without a dynamic library, fall back to the static libraries.
        interface_library = None
        libraries = _symlink_libraries(rctx, logger, info["static_libraries"], None)
        if not libraries:
            logger.info("No python libraries found.")

    if info["abi_dynamic_libraries"]:
        abi3_libraries, abi3_interface_library = _link_dynamic_pair(
            "abi_dynamic_libraries",
            "abi_interface_libraries",
        )
    else:
        abi3_interface_library = None
        abi3_libraries = []
        logger.info("No abi3 python libraries found.")

    additional_dlls = _symlink_libraries(rctx, logger, info["additional_dlls"], None)

    # List-valued and optional fields go through repr() because the template
    # substitutes them unquoted.
    build_bazel = _TOOLCHAIN_IMPL_TEMPLATE.format(
        major = info["major"],
        minor = info["minor"],
        micro = info["micro"],
        abi_flags = info["abi_flags"],
        os = "@platforms//os:{}".format(repo_utils.get_platforms_os_name(rctx)),
        implementation_name = info["implementation_name"],
        interpreter_path = repo_utils.norm_path(base_executable),
        interface_library = repr(interface_library),
        libraries = repr(libraries),
        defines = repr(info["defines"]),
        abi3_interface_library = repr(abi3_interface_library),
        abi3_libraries = repr(abi3_libraries),
        additional_dlls = repr(additional_dlls),
    )
    logger.debug(lambda: "BUILD.bazel\n{}".format(build_bazel))

    # Empty boundary files, followed by the generated BUILD file.
    rctx.file("WORKSPACE", "")
    rctx.file("MODULE.bazel", "")
    rctx.file("REPO.bazel", "")
    rctx.file("BUILD.bazel", build_bazel)

# Documentation for the `interpreter_path` attribute.
_INTERPRETER_PATH_DOC = """
An absolute path or program name on the `PATH` env var.

Values with slashes are assumed to be the path to a program. Otherwise, it is
treated as something to search for on `PATH`

Note that, when a plain program name is used, the path to the interpreter is
resolved at repository evalution time, not runtime of any resulting binaries.
"""

# Documentation for the `on_failure` attribute.
_ON_FAILURE_DOC = """
How to handle errors when trying to automatically determine settings.

* `skip` will silently skip creating a runtime. Instead, a non-functional
  runtime will be generated and marked as incompatible so it cannot be used.
  This is best if a local runtime is known not to work or be available
  in certain cases and that's OK. e.g., one use windows paths when there
  are people running on linux.
* `warn` will print a warning message. This is useful when you expect
  a runtime to be available, but are OK with it missing and falling back
  to some other runtime.
* `fail` will result in a failure. This is only recommended if you must
  ensure the runtime is available.
"""

# Public python attributes: the attribute schema used by
# `local_python_runtime`, exported so it can be reused elsewhere.
python_attrs = {
    "interpreter_path": attr.string(doc = _INTERPRETER_PATH_DOC),
    "on_failure": attr.string(
        default = "skip",
        values = ["fail", "skip", "warn"],
        doc = _ON_FAILURE_DOC,
    ),
    # Script shipped alongside these rules; declared as an attribute so it
    # can be referenced by label.
    "_get_runtime_info": attr.label(
        default = Label("//bazel/repo_rules:get_local_runtime_info.py"),
        allow_single_file = True,
    ),
}

# Environment variables that are used by python_configure: the `ENV_VARS`
# list exported by py_utils.bzl, plus the variable that, when provided,
# determines whether the remote python configuration repository is used.
PYTHON_ENV_VARS = _ENV_VARS + [_TENSORSTORE_PYTHON_CONFIG_REPO]

# Attribute schema of `local_python_runtime`: the public python attributes
# plus a private attribute carrying the rule's name.
_LOCAL_PYTHON_RUNTIME_ATTRS = dict(
    python_attrs,
    _rule_name = attr.string(default = "local_python_runtime"),
)

local_python_runtime = repository_rule(
    doc = """\
Detects and configures the local Python.

Add the following to your WORKSPACE FILE:

```python
local_python_runtime(name = "local_config_python")
```

Args:
  name: A unique name for this workspace rule.
""",
    implementation = _local_python_repo_impl,
    attrs = _LOCAL_PYTHON_RUNTIME_ATTRS,
    environ = PYTHON_ENV_VARS,
    # The repository is derived from the local system and inspects it for
    # configuration purposes.
    local = True,
    configure = True,
)
